Image Embeddings
- Generate image embeddings: use a pre-trained model such as ResNet50 (available in TensorFlow or PyTorch) to extract a feature vector from each real-world image.
  - Each image is loaded, resized to 224×224, and preprocessed by the generate_image_embedding function, which uses ResNet50's convolutional layers followed by global average pooling to produce a 2048-dimensional embedding.
- Store embeddings in FAISS: save the embeddings in a FAISS index for efficient similarity search.
  - FAISS is a similarity search library. The index used here, IndexFlatL2, compares the query against every stored vector by L2 (Euclidean) distance.
- Implement retrieval (the retrieval step of RAG): query the stored embeddings and return the top-k most similar images.
  - The query in this notebook is an image. ResNet50 embeds images only, so a text query would require a model that embeds text and images in the same space.
  - When a user provides a query image, we generate its embedding and search the FAISS index for the top-k nearest images by L2 distance. The function query_faiss_index handles this.
- Outputs:
  - The first table lists the stored images with their indices and the first 5 values of each embedding.
  - The second table lists the top-k images retrieved for the query, with their indices and the first 5 values of each embedding.
  - The images themselves are displayed alongside the text data.
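The index-and-query steps above can be illustrated without FAISS or a neural network. The following is a minimal sketch, assuming tiny made-up 3-dimensional vectors in place of real ResNet50 embeddings; `squared_l2` and `top_k_l2` are hypothetical helper names, not part of FAISS or of this notebook. It does the same kind of exhaustive comparison a flat L2 index performs and ranks by squared Euclidean distance, which is also the value FAISS's IndexFlatL2 reports.

```python
def squared_l2(a, b):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def top_k_l2(query, vectors, k=2):
    """Return (distances, indices) of the k stored vectors nearest to query."""
    # Score every stored vector, then sort by distance (ties broken by index).
    scored = sorted((squared_l2(query, v), i) for i, v in enumerate(vectors))
    nearest = scored[:k]
    return [d for d, _ in nearest], [i for _, i in nearest]


# Made-up 3-d "embeddings"; real ResNet50 vectors would be much longer.
stored = [
    [0.0, 1.0, 0.0],
    [1.0, 0.0, 0.0],
    [0.0, 0.9, 0.1],
    [1.0, 0.1, 0.0],
]

# Query with a vector that is already stored, as the notebook does.
query = stored[0]
distances, indices = top_k_l2(query, stored, k=2)
print(indices)
print(distances)
```

Because the query vector is itself one of the stored vectors, it comes back first with distance 0. The same happens in this notebook whenever the query image is one of the indexed files.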
In [ ]:
%pip install -q tensorflow keras-resnet faiss-cpu pandas numpy matplotlib tabulate
Note: you may need to restart the kernel using dbutils.library.restartPython() to use updated packages.
In [ ]:
#### best 2 images
import numpy as np
import pandas as pd
import faiss
import os
import matplotlib.pyplot as plt
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.resnet50 import preprocess_input

# Load ResNet50 model pre-trained on ImageNet (without top classification layer).
# pooling='avg' applies global average pooling, so each image maps to a 2048-d vector.
def load_resnet50_model():
    model = ResNet50(weights='imagenet', include_top=False, pooling='avg')
    return model

# Generate image embedding using ResNet50
def generate_image_embedding(img_path, model):
    img = image.load_img(img_path, target_size=(224, 224))  # Resize image to (224, 224)
    img_array = image.img_to_array(img)  # Convert image to array
    img_array = np.expand_dims(img_array, axis=0)  # Add batch dimension
    img_array = preprocess_input(img_array)  # Preprocess for ResNet50
    embedding = model.predict(img_array)  # Get embedding, shape (1, 2048)
    return embedding.flatten()  # Flatten to 1D vector
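
# Create FAISS index to store embeddings
def create_faiss_index(embeddings):
    # FAISS expects a C-contiguous float32 array of shape (n_vectors, dim)
    embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
    dim = embeddings.shape[1]  # Dimensionality of the embeddings
    index = faiss.IndexFlatL2(dim)  # Exact search using L2 (Euclidean) distance
    index.add(embeddings)  # Add embeddings to the FAISS index
    return index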
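
# Query FAISS index and retrieve top k similar images
def query_faiss_index(query_embedding, index, k=2):
    # FAISS expects a 2D float32 array: one row per query vector
    query = np.asarray(query_embedding, dtype=np.float32).reshape(1, -1)
    distances, indices = index.search(query, k)  # Get top-k neighbors (squared L2 distances)
    return distances, indices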

# Display image grid for top-k matches
def display_images(images, titles, embeddings, indices):
    fig, axes = plt.subplots(1, len(images), figsize=(15, 5))
    axes = np.atleast_1d(axes)  # plt.subplots returns a single Axes (not an array) for one image
    for i, ax in enumerate(axes):
        ax.imshow(images[i])
        ax.set_title(f"Index: {indices[i]}\nEmbedding: {embeddings[i][:5]}")
        ax.axis('off')
    plt.tight_layout()
    plt.show()

# Display table with image metadata (index and embeddings)
def display_image_table(data, title):
    df = pd.DataFrame(data, columns=['Index', 'Image', 'Embedding'])
    print(f"\n{title}")
    print(df.to_markdown(index=False))  # Markdown table formatting (requires the tabulate package)

# Main function to integrate the process
def main():
    # Load ResNet50 model
    model = load_resnet50_model()

    # Path to images directory (use user input or predefined path)
    image_directory = '/curated/ImageStore/images/'

    # Check if the directory exists
    if not os.path.exists(image_directory):
        print(f"Error: The directory '{image_directory}' does not exist.")
        return

    # List all image files in the directory and filter by extensions.
    # Sorted so that index numbers are reproducible across runs.
    image_files = sorted(f for f in os.listdir(image_directory) if f.lower().endswith(('.jpg', '.jpeg', '.png')))
    if not image_files:
        print(f"Error: No .jpg, .jpeg, or .png files found in '{image_directory}'.")
        return

    # Construct full image paths
    image_paths = [os.path.join(image_directory, img) for img in image_files]

    # Generate embeddings for the images
    embeddings = []
    for img_path in image_paths:
        embedding = generate_image_embedding(img_path, model)
        embeddings.append(embedding)
    embeddings = np.array(embeddings)  # Convert list to array for FAISS

    # Create FAISS index
    index = create_faiss_index(embeddings)

    # Display stored images, their index, and embeddings in a table
    stored_image_data = [(i, os.path.basename(image_paths[i]), embeddings[i][:5]) for i in range(len(image_paths))]
    display_image_table(stored_image_data, "Stored Images and Embeddings")

    # User query image: get the image name from input and add to the directory path
    user_query_image_name = 'elephant1.png'  # Hardcoded query image for now (can be replaced with user input)
    user_query_image = os.path.join(image_directory, user_query_image_name)

    # Check if the provided query image exists in the directory
    if not os.path.exists(user_query_image):
        print(f"Error: The image '{user_query_image_name}' was not found in the specified directory.")
        return

    # Generate query image embedding
    query_embedding = generate_image_embedding(user_query_image, model)

    # Query FAISS index for the top-2 similar images (fewer if the index holds fewer images)
    k = min(2, len(image_paths))
    distances, indices = query_faiss_index(query_embedding, index, k=k)

    # Retrieve top-k images and their metadata
    top_k_images = [image.load_img(image_paths[i], target_size=(224, 224)) for i in indices[0]]
    top_k_embeddings = [embeddings[i] for i in indices[0]]
    top_k_indices = indices[0]

    # Display the query image
    query_image = image.load_img(user_query_image, target_size=(224, 224))
    plt.imshow(query_image)
    plt.title(f"User Query: {user_query_image_name}")
    plt.axis('off')
    plt.show()

    # Display predicted top-k images, their index and embeddings
    predicted_image_data = [(top_k_indices[i], os.path.basename(image_paths[top_k_indices[i]]), top_k_embeddings[i][:5]) for i in range(k)]
    display_image_table(predicted_image_data, f"Predicted Top-{k} Images and Embeddings for: {user_query_image_name}")

    # Display the top-k images
    display_images(top_k_images, [f"Index: {i}" for i in top_k_indices], top_k_embeddings, top_k_indices)


if __name__ == "__main__":
    main()
1/1 [==============================] - 1s 752ms/step
1/1 [==============================] - 0s 48ms/step
1/1 [==============================] - 0s 50ms/step
1/1 [==============================] - 0s 48ms/step
1/1 [==============================] - 0s 48ms/step
1/1 [==============================] - 0s 48ms/step
1/1 [==============================] - 0s 51ms/step
1/1 [==============================] - 0s 48ms/step
1/1 [==============================] - 0s 50ms/step
1/1 [==============================] - 0s 49ms/step
1/1 [==============================] - 0s 48ms/step
1/1 [==============================] - 0s 47ms/step
1/1 [==============================] - 0s 47ms/step
1/1 [==============================] - 0s 51ms/step
1/1 [==============================] - 0s 48ms/step

Stored Images and Embeddings
| Index | Image | Embedding |
|--------:|:-------------------|:---------------------------------------------------------|
| 0 | airplane.jpg | [0.01681192 1.3374338 0.4865353 0.02957397 0.2203022 ] |
| 1 | car.jpg | [0.02365123 0.21462531 0.39598972 0.04758341 0.13171911] |
| 2 | elephant1.png | [0.30503857 2.5281742 0.13030702 0.40003565 0.8225097 ] |
| 3 | elephant3.png | [0.62227046 1.1155356 0.10596917 0.03490202 0.40784183] |
| 4 | elephant_face2.png | [0.03081354 2.703372 0.10826059 0.06344301 1.4554362 ] |
| 5 | fighter_jet1.png | [2.3893921 0.37274185 0.21137363 0.01661964 0. ] |
| 6 | fighter_jet2.png | [1.2492405 0.21738124 1.0315228 0. 0.23579091] |
| 7 | fighter_jet3.png | [0.94846344 0.21674678 1.0023813 0. 0.11890657] |
| 8 | kangaroo1.png | [0.89772236 0.94619936 0.5298464 0.25724941 0.05652738] |
| 9 | kangaroo2.png | [0.869192 0.6844135 0.53440905 0.09354267 1.3998648 ] |
| 10 | kangaroo3.png | [0.2655031 0.6898773 0.22355784 0.08943271 0.07943194] |
| 11 | koala1.png | [1.0157447 1.2002645 0.07522814 0.25566643 0. ] |
| 12 | koala2.png | [0.44358256 2.0638452 0.25086418 1.0025302 0.00942492] |
| 13 | koala3.png | [1.1000643 0.8675375 0.29340845 0.02682502 0.2982855 ] |
| 14 | koala4.png | [0.44358256 2.0638452 0.25086418 1.0025302 0.00942492] |

1/1 [==============================] - 0s 48ms/step

Predicted Top-2 Images and Embeddings for: elephant1.png
| Index | Image | Embedding |
|--------:|:-------------------|:---------------------------------------------------------|
| 2 | elephant1.png | [0.30503857 2.5281742 0.13030702 0.40003565 0.8225097 ] |
| 4 | elephant_face2.png | [0.03081354 2.703372 0.10826059 0.06344301 1.4554362 ] |
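
In the results above, the top match for elephant1.png is elephant1.png itself, because the query file is also stored in the index. One possible way to get k other images is to search for k + 1 neighbours and discard the query's own index. The sketch below assumes the search results are available as plain Python lists ordered nearest first; `drop_self_match` is a hypothetical helper, and the index and distance values are invented for illustration, not taken from this notebook's run.

```python
def drop_self_match(indices, distances, query_index, k):
    """Remove the query's own entry from a ranked result list and keep k hits."""
    # Keep every (index, distance) pair except the one for the query itself.
    kept = [(i, d) for i, d in zip(indices, distances) if i != query_index]
    kept = kept[:k]  # the input is already ordered nearest first
    return [i for i, _ in kept], [d for _, d in kept]


# Invented result of a search for k + 1 = 3 neighbours of stored item 7.
raw_indices = [7, 1, 5]
raw_distances = [0.0, 3.2, 4.8]

indices, distances = drop_self_match(raw_indices, raw_distances, query_index=7, k=2)
print(indices)
print(distances)
```

In this notebook, the helper could be applied to `indices[0].tolist()` and `distances[0].tolist()` after calling query_faiss_index with k + 1, with `query_index` taken from the position of the query file in image_paths. If the query image is not in the index, nothing is removed and the list is simply cut to k entries.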